# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
from typing import Dict, Iterable, List, Optional, Required, TypedDict, Union, cast

# FIXME: Converge on single set of types in either schemas.openai or openai.types
from openai.types.chat import ChatCompletionMessageToolCallParam
from openai.types.chat.chat_completion_message_tool_call_param import Function
from schemas.openai import (
    ChatCompletionMessageToolCall,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestMessage,
    ChatCompletionRequestMessageContentPart,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestUserMessage,
    Type1,
)


class ConversationMessage(TypedDict, total=False):
    """One chat turn in the dict shape handed to chat templates.

    Only ``role`` is required (``total=False`` makes every other key
    optional). The parsers in this module always populate ``content``; the
    remaining keys are added only when they apply to the turn.
    """

    # Role of the message's author, e.g. "user", "assistant" or "tool".
    role: Required[str]

    # Message body: a plain string, None, or a list of content-part dicts
    # such as {"type": "text", "text": "..."}.
    content: Optional[Union[str, List[Dict[str, str]]]]

    # For tool-role messages: ID of the tool call this message responds to.
    tool_call_id: Optional[str]

    # Optional author name; the parser in this module copies it only from
    # user messages.
    name: Optional[str]

    # For assistant messages: tool calls (e.g. function calls) produced by
    # the model. Their function arguments are decoded from JSON strings to
    # dicts after parsing.
    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]


def _frontend_schema_to_openai_schema_completion_tool_call(
    tool_call_param: ChatCompletionMessageToolCall,
) -> ChatCompletionMessageToolCallParam:
    """Re-express a frontend-schema tool call as an ``openai.types`` param.

    The id, type, function name and function arguments are copied over
    unchanged; in particular the arguments are not decoded here.
    """
    source_function = tool_call_param.function
    converted_function = Function(
        name=source_function.name,
        arguments=source_function.arguments,
    )
    return ChatCompletionMessageToolCallParam(
        id=tool_call_param.id,
        type=tool_call_param.type,
        function=converted_function,
    )


def _parse_chat_message_content_parts(
    role: str, parts: List[ChatCompletionRequestMessageContentPart]
) -> ConversationMessage:
    """Flatten a list of content parts into a ConversationMessage.

    Only text parts are accepted. Each one becomes
    ``{"type": "text", "text": ...}`` in the resulting ``content`` list,
    preserving input order.

    Raises:
        TypeError: On the first part whose type is not text.
    """

    def _as_text_entry(part: ChatCompletionRequestMessageContentPart) -> Dict:
        part_type = part.root.type
        # The type is compared against both the schema enum member and the
        # raw "text" string, so either representation is accepted.
        if part_type != Type1.text and part_type != "text":
            raise TypeError(f"only text message is supported, but got {part_type}")
        return {"type": "text", "text": part.root.text}

    text_entries = [_as_text_entry(part) for part in parts]
    return ConversationMessage(role=role, content=text_entries)


def _parse_chat_message_content(
    message: ChatCompletionRequestMessage,
) -> ConversationMessage:
    """Convert one frontend request message into a ConversationMessage.

    The role and content are always copied. String (or ``None``) content is
    passed through unchanged; a list of content parts is flattened by
    ``_parse_chat_message_content_parts``, which accepts only text parts.
    Role-specific fields are then attached:

    * ``assistant`` -> ``tool_calls`` (converted to the ``openai.types`` shape)
    * ``tool``      -> ``tool_call_id`` (when non-empty)
    * user messages -> ``name`` (when it is a string)

    Args:
        message: Request message from the frontend schema; the concrete
            message model is reached through ``message.root``.

    Returns:
        The parsed conversation message.

    Raises:
        TypeError: If the content is a list containing a non-text part.
    """
    root = message.root
    role = root.role
    content = root.content

    if content is None or isinstance(content, str):
        result_msg = ConversationMessage(role=role, content=content)
    else:
        # content is a list of message parts
        result_msg = _parse_chat_message_content_parts(role, content)

    if role == "assistant":
        assistant_msg = cast(ChatCompletionRequestAssistantMessage, root)
        if assistant_msg.tool_calls:
            result_msg["tool_calls"] = [
                _frontend_schema_to_openai_schema_completion_tool_call(tool_call)
                for tool_call in assistant_msg.tool_calls.root
            ]
    elif role == "tool":
        tool_msg = cast(ChatCompletionRequestToolMessage, root)
        if tool_msg.tool_call_id:
            result_msg["tool_call_id"] = tool_msg.tool_call_id

    # Only user messages contribute a ``name``; other roles' names are not
    # copied into the result.
    if isinstance(root, ChatCompletionRequestUserMessage) and isinstance(
        root.name, str
    ):
        result_msg["name"] = root.name

    return result_msg


def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    """Decode the tool-call arguments of assistant messages in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages with tool_calls need to be dicts, not JSON
    strings; this is how tool-use chat templates expect them going forward.
    The OpenAI format delivers them as strings, so each one is parsed here.

    Empty or missing arguments (absent key, ``None`` or ``""``) become ``{}``
    instead of being passed to ``json.loads``, which would raise on them.
    Arguments that are already non-string values (e.g. a dict from an
    earlier pass) are left untouched, so running this twice is harmless.

    Args:
        messages: Parsed conversation; modified in place.

    Raises:
        json.JSONDecodeError: If a non-empty arguments string is not valid
            JSON.
    """
    for message in messages:
        if message["role"] != "assistant":
            continue
        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue
        for item in tool_calls:
            function = item["function"]
            arguments = function.get("arguments")
            if not arguments:
                # Nothing to decode: json.loads("") / json.loads(None) raise.
                function["arguments"] = {}
            elif isinstance(arguments, str):
                function["arguments"] = json.loads(arguments)


def parse_chat_messages(
    messages: List[ChatCompletionRequestMessage],
) -> List[ConversationMessage]:
    """Turn OpenAI-style request messages into chat-template-ready dicts.

    Each message is converted independently and in order by
    ``_parse_chat_message_content``. The resulting list then gets a single
    in-place ``_postprocess_messages`` pass, which decodes tool-call
    arguments, before it is returned.
    """
    parsed_conversation = [_parse_chat_message_content(msg) for msg in messages]
    _postprocess_messages(parsed_conversation)
    return parsed_conversation


def load_chat_template(chat_template) -> Optional[str]:
    """Read a custom chat template from disk.

    Used when the user chooses a chat template different from the original
    one provided with the model's tokenizer.

    Args:
        chat_template: Path to the template file, or ``None`` when no
            override was requested.

    Returns:
        The file's contents, or ``None`` when ``chat_template`` is ``None``.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if chat_template is None:
        return None

    # Decode explicitly as UTF-8: the default text encoding is
    # locale-dependent, so a template containing non-ASCII text could be
    # mis-decoded or fail to load on hosts whose locale is not UTF-8.
    with open(chat_template, encoding="utf-8") as f:
        return f.read()
